import json
import os
import re
import time

import numpy as np
import openai
from nltk import pos_tag
from nltk.stem import WordNetLemmatizer
from openai import OpenAI
from sklearn.metrics.pairwise import cosine_similarity
from tqdm import tqdm

from utils.util import read_txt, read_json, write_json, is_json, write_txt
from utils.token_count_decorator import token_count_decorator
# from planning.src.metrics import Metrics
from planning.src.protocol import Protocol

class NoveltySeeker:
    """Split a domain corpus into novel candidates and match them to old protocols.

    Operates on JSON protocol files under ``planning/data/corpus/<domain>/`` and
    ``planning/data/candidate/<domain>/`` plus pre-computed embedding matrices
    stored as ``.npy`` files under ``planning/data/``.

    NOTE(review): the code indexes embedding rows by the position of a file in
    the *sorted* directory listing, so the ``.npy`` files must have been built
    in that order -- confirm against the script that produces them.
    """

    # Task modes; the numeric ``judgement`` labels index into this tuple.
    _MODES = ("planning", "modification", "adjustment")
    # Number of LLM attempts before giving up on one generation step.
    _MAX_ATTEMPTS = 5

    def __init__(self, domain) -> None:
        """Store the domain, derive the data paths and load the prompt templates."""
        self.domain = domain
        self.corpus_path = f"planning/data/corpus/{domain}/"
        self.candidate_path = f"planning/data/candidate/{domain}/"
        self.origin_path = "dataset/original_protocol/"
        self.name_mapping = {
            "Genetics": "Molecular Biology & Genetics",
            "Medical": "Biomedical & Clinical Research",
            "Ecology": "Ecology & Environmental Biology",
            "BioEng": "Bioengineering & Technology",
        }
        self.pseudofunctions_generation_prompt = read_txt("planning/data/prompt/pseudofunctions_generation.txt")
        self.pseudocode_to_json_prompt = read_txt("planning/data/prompt/pseudocode_to_json.txt")

    def convert_procedures_to_steps(self):
        """Add a newline-joined ``steps`` string to every corpus file, in place."""
        for filename in tqdm(os.listdir(self.corpus_path)):
            file_path = os.path.join(self.corpus_path, filename)
            protocol = read_json(file_path)
            protocol["steps"] = "\n".join(protocol["procedures"])
            write_json(file_path, protocol)

    def generate_pseudocode(self):
        """Fill in ``generated_pseudocode`` and ``program`` for every corpus file.

        Each step is skipped when the field is already present and is attempted
        at most ``_MAX_ATTEMPTS`` times otherwise. A file whose pseudocode could
        not be generated is still written back (with its ``id`` cast to ``str``)
        but gets no ``program``; previously this case raised ``KeyError``.
        """
        for filename in tqdm(os.listdir(self.corpus_path)):
            file_path = os.path.join(self.corpus_path, filename)
            protocol = read_json(file_path)

            if not protocol.get("generated_pseudocode"):
                pseudocode_prompt = (
                    self.pseudofunctions_generation_prompt
                    .replace("{title}", protocol["title"])
                    .replace("{protocol}", protocol["steps"])
                )
                for _ in range(self._MAX_ATTEMPTS):
                    response = self.__chatgpt_function(content=pseudocode_prompt)
                    code_blocks = re.findall(r'```python([^`]*)```', response, re.DOTALL)
                    try:
                        # IndexError: no fenced block; ValueError: marker missing.
                        pseudofunctions, pseudocode = code_blocks[0].split("# Protocol steps", 1)
                    except (IndexError, ValueError):
                        continue
                    if pseudofunctions and pseudocode:
                        protocol["generated_pseudocode"] = code_blocks[0].strip()
                        break

            generated = protocol.get("generated_pseudocode")
            if not generated:
                print(f"No pseudocode could be generated for {filename}; skipping JSON conversion.")
            elif not protocol.get("program"):
                program_prompt = self.pseudocode_to_json_prompt.replace("{pseudocode}", generated)
                for _ in range(self._MAX_ATTEMPTS):
                    response = self.__chatgpt_function(content=program_prompt)
                    json_blocks = re.findall(r'```json([^`]*)```', response, re.DOTALL)
                    if json_blocks and is_json(plan := json_blocks[0].strip()):
                        protocol["program"] = json.loads(plan)
                        break

            protocol["id"] = str(protocol["id"])
            write_json(file_path, protocol)

    def flatten_embedding(self, num_candidates=100):
        """Move the least-similar corpus protocols into the candidate directory.

        Args:
            num_candidates: How many protocols with the lowest average cosine
                similarity to the rest of the corpus are moved.

        Raises:
            ValueError: If the number of embedding rows differs from the number
                of corpus files. Files are moved and deleted by row index, so a
                mismatch would silently relocate the wrong protocols.
        """
        emb_matrix = np.load(f"planning/data/{self.domain}.npy")
        filenames = sorted(os.listdir(self.corpus_path))
        if emb_matrix.shape[0] != len(filenames):
            raise ValueError(
                f"Embedding matrix has {emb_matrix.shape[0]} rows but "
                f"{self.corpus_path} contains {len(filenames)} files."
            )

        sim_matrix = cosine_similarity(emb_matrix)
        # Zero the diagonal so a protocol's similarity to itself does not count.
        off_diagonal = np.where(np.eye(sim_matrix.shape[0], dtype=bool), 0, sim_matrix)
        avg_similarities = np.mean(off_diagonal, axis=1)
        least_similar_indices = np.argsort(avg_similarities)[:num_candidates]

        os.makedirs(self.candidate_path, exist_ok=True)
        for filename in tqdm([filenames[i] for i in least_similar_indices]):
            source_path = os.path.join(self.corpus_path, filename)
            write_json(os.path.join(self.candidate_path, filename), read_json(source_path))
            os.remove(source_path)

    def seek(self):
        """Find, per candidate and per mode, the highest-scoring corpus protocol.

        Writes ``planning/data/<domain>_novelty.json``: one entry per candidate
        holding the title/description of the candidate and of its best match
        for each mode, plus an empty ``judgement`` field.
        """
        self._load_embedding_index()
        # Both dicts were filled in sorted order, so their keys are the listings.
        corpus_filenames = list(self.corpus_i)
        candidate_filenames = list(self.candidate_i)

        novelty = []
        for candidate_filename in tqdm(candidate_filenames):
            candidate_json = read_json(os.path.join(self.candidate_path, candidate_filename))
            candidate_dict = {
                "novel protocol": Protocol.fromjson(candidate_json).tojson("title", "description")
            }
            for mode in self._MODES:
                scores = [
                    self.weigher(corpus_filename, candidate_filename, mode)
                    for corpus_filename in tqdm(corpus_filenames, leave=False, desc=mode)
                ]
                closest_corpus = corpus_filenames[scores.index(max(scores))]
                closest_json = read_json(os.path.join(self.corpus_path, closest_corpus))
                candidate_dict[f"old protocol - {mode}"] = Protocol.fromjson(closest_json).tojson("title", "description")
            candidate_dict["judgement"] = ""
            novelty.append(candidate_dict)
        write_json(f"planning/data/{self.domain}_novelty.json", novelty)

    def weigher(self, corpus_filename, candidate_filename, mode):
        """Score how close a corpus protocol is to a candidate protocol.

        Requires ``corpus_embeddings``, ``candidate_embeddings``, ``corpus_i``
        and ``candidate_i`` to be populated (``seek`` and the ``difficulty_cal*``
        methods do this before calling).

        Args:
            corpus_filename: File name inside the corpus directory.
            candidate_filename: File name inside the candidate directory.
            mode: One of ``"planning"``, ``"modification"``, ``"adjustment"``;
                selects the weighting of the three component scores.

        Returns:
            The weighted sum of the description cosine similarity and the
            ``Metrics`` dimensions 1 and 4 (dimension 4 is unused for
            ``"adjustment"``).

        Raises:
            ValueError: If ``mode`` is not one of the three known modes.
        """
        if mode not in self._MODES:
            raise ValueError("Wrong mode!")

        # The module-level import of Metrics is commented out, which left this
        # method raising NameError. Imported lazily here instead.
        # NOTE(review): presumably disabled to avoid an import cycle -- confirm.
        from planning.src.metrics import Metrics

        co_program = Protocol.fromjson(read_json(os.path.join(self.corpus_path, corpus_filename)))
        ca_program = Protocol.fromjson(read_json(os.path.join(self.candidate_path, candidate_filename)))
        co_idx = self.corpus_i[corpus_filename]
        ca_idx = self.candidate_i[candidate_filename]
        metrics = Metrics(
            novel_protocol=ca_program,
            groundtruth_protocol=co_program,
            novel_program_type="pseudocode"
        )
        description_cosine = self.cos(self.corpus_embeddings[co_idx], self.candidate_embeddings[ca_idx])
        iou_on_op = metrics.get_dimension_1()

        if mode == "adjustment":
            return 0.5 * description_cosine + 0.5 * iou_on_op

        op_seq_sim = metrics.get_dimension_4()
        if mode == "planning":
            return 0.2 * description_cosine + 0.2 * iou_on_op + 0.6 * (op_seq_sim)
        # mode == "modification"
        return 0.4 * (description_cosine + iou_on_op) + 0.2 * (op_seq_sim)

    def dataset_construct(self):
        """Materialise the judged novelty pairs as per-task dataset files.

        Reads ``<domain>_novelty_classified.json``; for every entry whose
        ``judgement`` is neither empty nor ``"NONE"`` the judgement is used as
        an index into the task list, the candidate protocol is copied to
        ``dataset/planning/<domain>/<task>/<id>.json`` and the candidate-id to
        old-id mapping is recorded in ``planning/data/novel_to_old.json``.
        """
        candidate_title_id = self._title_to_id(self.candidate_path)
        # Differing counts reveal duplicate titles that collapsed in the mapping.
        print(len(os.listdir(self.candidate_path)), len(candidate_title_id))
        corpus_title_id = self._title_to_id(self.corpus_path)
        print(len(os.listdir(self.corpus_path)), len(corpus_title_id))

        novel_to_old = {}
        data_classified = read_json(f"planning/data/Dataset_20240925/{self.domain}_novelty_classified.json")
        for data in tqdm(data_classified):
            judgement = data.get("judgement", "").strip()
            if judgement in ("", "NONE"):
                continue
            task = self._MODES[int(judgement)]
            novel_id = candidate_title_id[data["novel protocol"]["title"]]
            groundtruth = read_json(f"{self.candidate_path}{novel_id}.json")
            old_id = corpus_title_id[data[f"old protocol - {task}"]["title"]]
            novel_to_old[novel_id] = old_id
            task_dir = f"dataset/planning/{self.domain}/{task}/"
            os.makedirs(task_dir, exist_ok=True)
            write_json(f"{task_dir}{novel_id}.json", groundtruth)

        mapping_path = "planning/data/novel_to_old.json"
        # Start from an empty mapping on the very first run.
        novel_to_old_json = read_json(mapping_path) if os.path.exists(mapping_path) else {}
        novel_to_old_json[self.domain] = novel_to_old
        write_json(mapping_path, novel_to_old_json)

    def difficulty_cal(self):
        """Score every dataset protocol against its recorded old protocol.

        Returns:
            Dict mapping each mode to the list of ``weigher`` scores of the
            files in ``dataset/planning/<domain>/<mode>/``.
        """
        self._load_embedding_index()
        novel_to_old = read_json("planning/data/novel_to_old.json")[self.domain]
        mode_scores = {}
        for mode in self._MODES:
            mode_dir = f"dataset/planning/{self.domain}/{mode}/"
            scores = []
            for novel_filename in tqdm(os.listdir(mode_dir), desc=mode):
                novel_id = read_json(f"{mode_dir}{novel_filename}")["id"]
                # Keys of a JSON-loaded dict are always strings.
                old_id = novel_to_old[str(novel_id)]
                scores.append(
                    self.weigher(corpus_filename=f"{old_id}.json", candidate_filename=novel_filename, mode=mode)
                )
            mode_scores[mode] = scores
        return mode_scores

    def difficulty_cal_tot(self):
        """Score every pair listed in the novelty file, for all three modes.

        Returns:
            Dict mapping each mode to its list of ``weigher`` scores; a mode is
            absent when the novelty file has no entries.
        """
        self._load_embedding_index()
        candidate_title_id = self._title_to_id(self.candidate_path)
        corpus_title_id = self._title_to_id(self.corpus_path)
        novelty = read_json(f"planning/data/Dataset_20240925/{self.domain}_novelty.json")

        mode_scores = {}
        for novel_old_dict in tqdm(novelty):
            novel_id = candidate_title_id[novel_old_dict["novel protocol"]["title"]]
            for mode in self._MODES:
                old_id = corpus_title_id[novel_old_dict[f"old protocol - {mode}"]["title"]]
                score = self.weigher(corpus_filename=f"{old_id}.json", candidate_filename=f"{novel_id}.json", mode=mode)
                mode_scores.setdefault(mode, []).append(score)
        return mode_scores

    def cos(self, emb_a, emb_b):
        """Return the cosine similarity of two vectors, or ``0.0`` when undefined.

        NumPy does not raise on division by zero (it yields ``nan``/``inf`` with
        a warning), so zero-norm and non-finite inputs are checked explicitly.
        Any exception from NumPy (e.g. mismatched shapes) also yields ``0.0``.
        """
        try:
            denominator = np.linalg.norm(emb_a) * np.linalg.norm(emb_b)
            if denominator == 0 or not np.isfinite(denominator):
                return 0.0
            return np.dot(emb_a, emb_b) / denominator
        except Exception:
            # Deliberate best effort: an unusable pair simply scores zero.
            return 0.0

    def _load_embedding_index(self):
        """Load both embedding matrices and map file names to row indices."""
        self.corpus_embeddings = np.load(f"planning/data/{self.domain}_corpus.npy")
        self.candidate_embeddings = np.load(f"planning/data/{self.domain}_candidate.npy")
        self.corpus_i = {filename: i for i, filename in enumerate(sorted(os.listdir(self.corpus_path)))}
        self.candidate_i = {filename: i for i, filename in enumerate(sorted(os.listdir(self.candidate_path)))}

    @staticmethod
    def _title_to_id(directory):
        """Map ``title`` to ``id`` for every JSON file in ``directory``.

        ``directory`` must end with a path separator. Later files overwrite
        earlier ones that share a title.
        """
        title_to_id = {}
        for filename in os.listdir(directory):
            protocol = read_json(f"{directory}{filename}")
            title_to_id[protocol["title"]] = protocol["id"]
        return title_to_id

    @token_count_decorator(flow="together", batch=False)
    def __chatgpt_function(self, content, gpt_model="gpt-4o-mini"):
        """Send ``content`` as a single user message and return the reply text.

        Retries indefinitely on ``openai.APIError``, sleeping with a capped
        exponential back-off between attempts instead of hammering the API.
        A reply without text content is returned as ``""`` so callers can
        always apply string operations to the result.
        """
        client = OpenAI(
            api_key=os.environ.get("OPENAI_API_KEY"),
        )
        delay = 1.0
        while True:
            try:
                chat_completion = client.chat.completions.create(
                    messages=[
                        {"role": "user", "content": content}
                    ],
                    model=gpt_model
                )
            except openai.APIError as error:
                print(error)
                time.sleep(delay)
                delay = min(delay * 2, 60.0)
                continue
            return chat_completion.choices[0].message.content or ""

class DataPicker:
    """Filter the planning dataset and attach DSL ground-truth programs.

    Reads protocols from ``dataset/planning/<domain>/<task>/``, copies the ones
    that pass the filters in :meth:`pick` to
    ``dataset/planning_picked/<domain>/<task>/`` and caches LLM-extracted
    metadata in ``planning/data/dataset_metadata.json``.
    """

    # Number of LLM attempts before giving up on one generation step.
    _MAX_ATTEMPTS = 5
    # Protocols with more operations than this outside the DSL are rejected.
    _MAX_UNKNOWN_OPERATIONS = 7

    def __init__(self, domain) -> None:
        """Store the domain, load its DSL, the metadata cache and the prompts."""
        self.domain = domain
        self.dataset_path = f"dataset/planning/{domain}/"
        self.dataset_picked_path = f"dataset/planning_picked/{domain}/"
        self.tasks = ["planning", "modification", "adjustment"]
        self.operation_dsl, self.production_dsl = self.load_dsl()
        self.dataset_metadata_path = "planning/data/dataset_metadata.json"
        self.dataset_metadata = read_json(self.dataset_metadata_path)
        self.lemmatizer = WordNetLemmatizer()
        self.opcodes = [op.lower() for op in self.operation_dsl]
        self.program_components_extraction_prompt = read_txt("planning/data/prompt/program_components_extraction.txt")
        self.device_extraction_prompt = read_txt("planning/data/prompt/device_extraction.txt")
        self.program_devices_extraction_prompt = read_txt("planning/data/prompt/program_devices_extraction.txt")
        self.dsl_gt_generation_prompt = read_txt("planning/data/prompt/dsl_gt_generation.txt")
        self.multi_dsl_gt_generation_prompt = read_txt("planning/data/prompt/multi_dsl_gt_generation.txt")

    def load_dsl(self):
        """Return ``(operation_dsl, production_dsl)`` read from ``dsl_result/<domain>/``."""
        operation_dsl = read_json(f"dsl_result/{self.domain}/operation_dsl.json")
        production_dsl = read_json(f"dsl_result/{self.domain}/production_dsl.json")
        return operation_dsl, production_dsl

    def pick(self):
        """Copy the protocols that pass all filters into the picked dataset.

        A protocol is rejected when more than ``_MAX_UNKNOWN_OPERATIONS`` of its
        operations are not DSL opcodes, when its extracted components contain
        ``"NONE"``, or when no device is found by either extraction route. For
        accepted protocols the longer of the two device lists is stored in the
        metadata cache, which is written once at the end.
        """
        known_opcodes = set(self.opcodes)
        for task in self.tasks:
            task_dir = f"{self.dataset_path}{task}/"
            picked_dir = f"{self.dataset_picked_path}{task}/"
            filenames = os.listdir(task_dir)
            count = 0
            for filename in tqdm(filenames, desc=task):
                source_path = f"{task_dir}{filename}"
                protocol = Protocol.fromjson(read_json(source_path))

                operations = self.__get_operations_sequence(protocol.program)
                unknown_operations = [op for op in operations if op not in known_opcodes]
                if len(unknown_operations) > self._MAX_UNKNOWN_OPERATIONS:
                    print(unknown_operations)
                    continue

                if "NONE" in self.__get_components(protocol):
                    continue

                devices_gt = self.__get_devices(protocol, program_type="groundtruth")
                devices_pc = self.__get_devices(protocol, program_type="pseudocode")
                if not devices_gt and not devices_pc:
                    continue
                devices = max(devices_gt, devices_pc, key=len)

                domain_metadata = self.dataset_metadata.setdefault(self.domain, {})
                domain_metadata.setdefault(protocol.id, {})["devices"] = devices
                os.makedirs(picked_dir, exist_ok=True)
                # Read the file again so the copy reflects what is on disk.
                write_json(f"{picked_dir}{filename}", read_json(source_path))
                count += 1
            print(task, len(filenames), count)
        self.__dump_dataset_metadata()

    def dsl_gt_generation(self):
        """Generate ``dsl_program`` and ``multi_dsl_program`` for picked protocols.

        Each program is requested up to ``_MAX_ATTEMPTS`` times. When no valid
        JSON is obtained the corresponding key is left untouched and a message
        is printed; previously this either raised ``NameError`` or silently
        stored the program produced for an earlier prompt or file.
        """
        for task in self.tasks:
            task_dir = f"{self.dataset_picked_path}{task}/"
            for filename in tqdm(os.listdir(task_dir), desc=task):
                file_path = f"{task_dir}{filename}"
                protocol_json = read_json(file_path)
                protocol = Protocol.fromjson(protocol_json)
                # (output key, prompt template, progress label, failed-response dump)
                targets = (
                    ("dsl_program", self.dsl_gt_generation_prompt, "dsl-try", "test.txt"),
                    ("multi_dsl_program", self.multi_dsl_gt_generation_prompt, "multi-dsl-try", None),
                )
                for key, template, label, dump_path in targets:
                    prompt = template.replace("{title}", protocol.title).replace("{protocol}", protocol.steps)
                    plan = self.__request_json_program(prompt, label, dump_path)
                    if plan is None:
                        print(f"No valid JSON for '{key}' of {file_path}; leaving it unchanged.")
                    else:
                        protocol_json[key] = json.loads(plan)
                write_json(file_path, protocol_json)

    def __request_json_program(self, prompt, attempt_label, dump_path=None):
        """Ask the LLM for a fenced JSON block and return its text, or ``None``.

        Args:
            prompt: Fully substituted prompt.
            attempt_label: Printed before every attempt.
            dump_path: When given, each failed attempt prints the number of
                fenced blocks found and writes the raw response to this file.
        """
        for _ in range(self._MAX_ATTEMPTS):
            print(attempt_label)
            response = self.__chatgpt_function(prompt)
            json_blocks = re.findall(r'```json([^`]*)```', response, re.DOTALL)
            if json_blocks and is_json(plan := json_blocks[0].strip()):
                return plan
            if dump_path is not None:
                print(len(json_blocks))
                write_txt(dump_path, response)
        return None

    def __get_operations_sequence(self, program: dict) -> list:
        """Return the first verb of every function name in ``program``, in order.

        Names for which no non-empty word is found are skipped.
        """
        return [
            first_verb
            for func_name in program
            if (first_verb := self.__get_first_verb(func_name))
        ]

    def __get_first_verb(self, operation_str):
        """Return the first verb-tagged word of an operation name.

        The name is split on underscores and spaces, each token is lemmatized
        as a verb and lower-cased, then POS-tagged. Falls back to the first
        token when nothing is tagged as a verb, and to ``""`` when the name
        holds no tokens at all.
        """
        # Leading or doubled separators yield empty tokens; drop them so they
        # neither reach the tagger nor become the fallback result.
        tokens = [token for token in re.split(r'[_ ]', operation_str) if token]
        if not tokens:
            return ""
        lemmatized_tokens = [self.lemmatizer.lemmatize(token, pos="v").lower() for token in tokens]

        for word, pos in pos_tag(lemmatized_tokens):
            if pos.startswith('VB'):  # VB, VBD, VBG, VBN, VBP, VBZ
                return word
        return lemmatized_tokens[0]

    def __get_components(self, protocol: Protocol) -> list:
        """Return the protocol's ``flowunits``, extracting them on a cache miss.

        On a miss the LLM is asked up to ``_MAX_ATTEMPTS`` times for a
        comma-separated list; the first non-empty answer is cached and the
        metadata file is written immediately. Returns ``[]`` if every attempt
        comes back empty.
        """
        domain_metadata = self.dataset_metadata.setdefault(self.domain, {})
        flowunits = domain_metadata.get(protocol.id, {}).get("flowunits", [])
        if flowunits:
            return flowunits

        prompt = self.program_components_extraction_prompt.replace(
            "---PSEUDOCODE---", json.dumps(protocol.program, indent=4, ensure_ascii=False)
        )
        for _ in range(self._MAX_ATTEMPTS):
            response = self.__chatgpt_function(prompt)
            flowunits = [flowunit.strip() for flowunit in response.split(",") if flowunit.strip()]
            if flowunits:
                domain_metadata.setdefault(protocol.id, {})["flowunits"] = flowunits
                self.__dump_dataset_metadata()
                break
        return flowunits

    def __get_devices(self, protocol: Protocol, program_type: str) -> list:
        """Return the devices used by ``protocol``.

        Args:
            protocol: Protocol to inspect.
            program_type: ``"dsl"`` reads ``DeviceType`` from each step's
                ``Execution`` entry; ``"pseudocode"`` asks the LLM once about
                the whole program; ``"groundtruth"`` asks the LLM about every
                numbered step sentence.

        Raises:
            ValueError: If ``program_type`` is none of the three values above.
        """
        if program_type == "dsl":
            devices = []
            for step in protocol.program:
                if "Execution" not in step:
                    continue
                execution = step["Execution"]
                if isinstance(execution, dict):
                    devices.append(execution["DeviceType"])
                elif isinstance(execution, list):
                    devices.extend(device_dict["DeviceType"] for device_dict in execution)
            return devices

        if program_type == "pseudocode":
            prompt = self.program_devices_extraction_prompt.replace(
                "---PSEUDOCODE---", json.dumps(protocol.program, indent=4, ensure_ascii=False)
            )
            response = self.__chatgpt_function(prompt)
            devices = [device.strip() for device in response.split(",") if device.strip()]
            if "NONE" in devices:
                return []
            return devices

        if program_type == "groundtruth":
            devices = []
            for sentence in self.__convert_to_sentence_list(protocol.steps):
                device_list = self.__device_extraction(sentence)
                if "NONE" not in device_list:
                    devices.extend(device_list)
            return devices

        raise ValueError(f"Unknown program_type: {program_type!r}")

    def __convert_to_sentence_list(self, steps):
        """Return the stripped, non-empty lines of ``steps`` that start with ``<digits>.``."""
        if not steps:
            return []
        stripped_lines = (line.strip() for line in steps.split("\n"))
        return [line for line in stripped_lines if line and re.match(r'^\d+\.', line)]

    def __device_extraction(self, sentence):
        """Ask the LLM for the devices in one sentence.

        Returns ``["NONE"]`` when the reply contains ``NONE``, otherwise the
        comma-separated items of the reply.
        """
        prompt = self.device_extraction_prompt.replace("---SENTENCES---", sentence)
        response = self.__chatgpt_function(prompt).strip()
        if "NONE" in response:
            return ["NONE"]
        return [device.strip() for device in response.split(",") if device.strip()]

    @token_count_decorator(flow="together", batch=False)
    def __chatgpt_function(self, content, gpt_model="gpt-4o-mini"):
        """Send ``content`` as a single user message and return the reply text.

        Retries indefinitely on ``openai.APIError``, sleeping with a capped
        exponential back-off between attempts instead of hammering the API.
        A reply without text content is returned as ``""`` so callers can
        always apply string operations to the result.
        """
        client = OpenAI(
            api_key=os.environ.get("OPENAI_API_KEY"),
        )
        delay = 1.0
        while True:
            try:
                chat_completion = client.chat.completions.create(
                    messages=[
                        {"role": "user", "content": content}
                    ],
                    model=gpt_model
                )
            except openai.APIError as error:
                print(error)
                time.sleep(delay)
                delay = min(delay * 2, 60.0)
                continue
            return chat_completion.choices[0].message.content or ""

    def __dump_dataset_metadata(self):
        """Write the in-memory metadata cache back to its JSON file."""
        write_json(self.dataset_metadata_path, self.dataset_metadata)